{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## W4A16 Quantization and Compression ##\n",
    "\n",
    "Using compressed-tensors, we can compress a quantized model to store it more efficiently on disk.\n",
    "\n",
    "In this example, we run post-training quantization (PTQ) to quantize the weights of an example model to 4 bits. We then save a compressed version of the model on disk by packing each group of eight 4-bit weights into a single int32\n",
    "\n",
    "By packing groups of eight 4-bit weights into a single int32, we can store a quantized model more efficiently on disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from compressed_tensors.quantization import (\n",
    "    QuantizationConfig,\n",
    "    QuantizationStatus,\n",
    "    apply_quantization_config,\n",
    "    compress_quantized_weights\n",
    ")\n",
    "from compressed_tensors.compressors import ModelCompressor\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import RandomSampler\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c883cdc8ecd04866bd01d61796b81c26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32b18b14b6774ce7b61d2854a1ed5f49",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "370c6d18521a4b65833a411728be1ed7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(32000, 2048)\n",
       "    (layers): ModuleList(\n",
       "      (0-21): 22 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (k_proj): Linear(in_features=2048, out_features=256, bias=False)\n",
       "          (v_proj): Linear(in_features=2048, out_features=256, bias=False)\n",
       "          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
       "          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
       "          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((2048,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load a dense, unquantized tiny llama model\n",
    "device = \"cuda:0\"\n",
    "model_name = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=\"auto\")\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following quantization config will be used to quantize all of the Linear layers to 4 bits, excluding the lm_head layer. \n",
    "\n",
    "The `format` argument is set to `pack-quantized`, indicating that when the model is saved we should use the `PackedQuantizationCompressor` which will pack every eight 4-bit weights into an `int32`. \n",
    "\n",
    "This will give us a compression ratio of 8x on each Linear layer compared to the unquantized `float32` representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "quantization_config_dict = {\n",
    "\t\"quant_method\": \"sparseml\",\n",
    "\t\"format\": \"pack-quantized\",\n",
    "\t\"global_compression_ratio\": None,\n",
    "\t\"config_groups\": {\n",
    "        \"group_1\": {\n",
    "            \"weights\": {\n",
    "                \"num_bits\": 4,\n",
    "                \"type\": \"int\",\n",
    "                \"symmetric\": False,\n",
    "                \"strategy\": \"tensor\"\n",
    "            },\n",
    "            \"targets\": [\"Linear\"]\n",
    "        }\n",
    "    },\n",
    "\t\"ignore\": [\"lm_head\"]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup the loaded model for quantization calibration\n",
    "\n",
    "config = QuantizationConfig(**quantization_config_dict)\n",
    "config.quantization_status = QuantizationStatus.CALIBRATION\n",
    "apply_quantization_config(model, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a dataloader of calibration data\n",
    "\n",
    "dataset = load_dataset(\"ptb_text_only\", trust_remote_code=True)[\"train\"]\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"sentence\"], padding=False, truncation=True, max_length=1024)\n",
    "\n",
    "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
    "\n",
    "data_loader = DataLoader(\n",
    "    tokenized_dataset, batch_size=1, collate_fn=DefaultDataCollator(), sampler=RandomSampler(tokenized_dataset)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running calibration: 512it [00:33, 15.42it/s]\n"
     ]
    }
   ],
   "source": [
    "# calibrate scale and zero points for quantization using a small amount of train data\n",
    "num_calibration_samples = 512\n",
    "\n",
    "with torch.no_grad():\n",
    "    for idx, sample in tqdm(enumerate(data_loader), desc=\"Running calibration\"):\n",
    "        sample = {key: value.to(device) for key,value in sample.items()}\n",
    "        _ = model(**sample)\n",
    "\n",
    "        if idx >= num_calibration_samples:\n",
    "            break\n",
    "\n",
    "# freeze scale and zero points after calibration\n",
    "# model.apply(freeze_module_quantization)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After running calibration, each quantized layer will have a new scale and zero_point parameter as shown below.\n",
    "\n",
    "Notice that at this point, the weight itself is still a floating point and has not been quantized. \n",
    "\n",
    "To convert the weights to an integer type, we need to apply the `compress_quantized_weights` function. After compressing the weights, a forward pass of the model can no longer be run in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scale: tensor([17296.], device='cuda:4', dtype=torch.float16), Zero Point: tensor([0], device='cuda:4', dtype=torch.int8)\n",
      "Weight min: -1.587890625 max: 1.0283203125 dtype: torch.float16\n"
     ]
    }
   ],
   "source": [
    "state_dict = model.state_dict()\n",
    "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n",
    "scale = state_dict[example_layer + \"_scale\"]\n",
    "zero_point = state_dict[example_layer + \"_zero_point\"]\n",
    "weight = state_dict[example_layer]\n",
    "print(f\"Scale: {scale}, Zero Point: {zero_point}\")\n",
    "print(f\"Weight min: {torch.min(weight)} max: {torch.max(weight)} dtype: {weight.dtype}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scale: tensor([17296.], device='cuda:4', dtype=torch.float16), Zero Point: tensor([0], device='cuda:4', dtype=torch.int8)\n",
      "Weight min: 0 max: 0 dtype: torch.int8\n"
     ]
    }
   ],
   "source": [
    "# convert quantized weights to integers\n",
    "model.apply(compress_quantized_weights)\n",
    "\n",
    "state_dict = model.state_dict()\n",
    "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n",
    "scale = state_dict[example_layer + \"_scale\"]\n",
    "zero_point = state_dict[example_layer + \"_zero_point\"]\n",
    "weight = state_dict[example_layer]\n",
    "print(f\"Scale: {scale}, Zero Point: {zero_point}\")\n",
    "print(f\"Weight min: {torch.min(weight)} max: {torch.max(weight)} dtype: {weight.dtype}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After compressing the quantized model, the weight matrix has a range of int4 but is stored in an int8. \n",
    "\n",
    "We can further compress the model on disk using the `pack-quantized` format we specified in the config. This compression format will pack the int4 weights into int32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compression format: pack-quantized\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Quantized Compression: 100%|██████████| 509/509 [00:03<00:00, 153.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of the model's weights on disk using safetensors: 712.23 MB\n"
     ]
    }
   ],
   "source": [
    "# apply compression and save the model to disk\n",
    "\n",
    "output_dir = \"./ex_llama1.1b_w4a16_packed_quantize\"\n",
    "compression_format = config.format\n",
    "print(f\"Compression format: {compression_format}\")\n",
    "\n",
    "compressor = ModelCompressor(quantization_config=config)\n",
    "compressed_state_dict = compressor.compress(model)\n",
    "model.save_pretrained(output_dir, state_dict=compressed_state_dict)\n",
    "compressor.update_config(output_dir)\n",
    "\n",
    "compressed_size_on_disk_mb = os.path.getsize(os.path.join(output_dir, \"model.safetensors\")) / 1024 / 1024\n",
    "print(f\"Size of the model's weights on disk using safetensors: {compressed_size_on_disk_mb:.2f} MB\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
